#!/bin/bash

# Copyright (c) 2018-2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


###########################################################################
# PyTorch multi-node jobs require a bunch of non-standard envvars to be set.
# This script derives and sets reasonable values for those variables from SLURM.
###########################################################################

set -euo pipefail

# Only derive the variables when we are in a PyTorch container, detected by
# PYTORCH_VERSION being set and non-empty.  Otherwise the environment is
# handed to the child command untouched.
if [[ -n "${PYTORCH_VERSION-}" ]]; then

    # MASTER_ADDR (or MLPERF_SLURM_FIRSTNODE) should be set from host scripts.
    # Setting it requires `scontrol`, which is typically unavailable from
    # inside containers.  A typical way to set it from the host would be:
    #     $(scontrol show hostnames "${SLURM_JOB_NODELIST-}" | head -n1)
    # If neither is set we fall back to 127.0.0.1, which works for single-node
    # jobs but will fail for multi-node jobs.
    export MASTER_ADDR="${MASTER_ADDR:-${MLPERF_SLURM_FIRSTNODE:-127.0.0.1}}"

    # MASTER_PORT is needed for static rendezvous.  Unfortunately
    # torch.distributed TCPStore doesn't handle races for dynamic port
    # assignment in the ephemeral port range gracefully, so to mitigate we
    # choose a static port number that is neither in the ephemeral port range
    # defined by IANA (49152-65535), nor in the ephemeral port range defined by
    # default in Ubuntu (32768-60999).  That leaves the range 1024-32767.
    # https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py
    # uses 29500 (and the torch distributed docs use 29501), so we also use
    # 29500.
    # https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.txt
    # shows that 29500 is not currently assigned to anyone else, and since it
    # is also not ephemeral it should always be free.  If your org has some
    # other service assigned to this particular port, then find another port in
    # 1024-32767 that is not assigned in the IANA list and manually override
    # MASTER_PORT to that number.
    export MASTER_PORT="${MASTER_PORT:-29500}"

    # WORLD_SIZE and RANK are required for
    #     torch.distributed.init_process_group(backend='nccl',
    #                                          init_method='env://')
    # Precedence for each variable: an explicitly set value, then the Slurm
    # variable (srun), then the OMPI_COMM_WORLD variable (mpirun).  If we are
    # running under neither, fall back to values appropriate for a single GPU.
    export WORLD_SIZE="${WORLD_SIZE:-${SLURM_NTASKS:-${OMPI_COMM_WORLD_SIZE:-1}}}"
    export RANK="${RANK:-${SLURM_PROCID:-${OMPI_COMM_WORLD_RANK:-0}}}"

    # By convention we use LOCAL_RANK for torch.cuda.set_device().
    export LOCAL_RANK="${LOCAL_RANK:-${SLURM_LOCALID:-${OMPI_COMM_WORLD_LOCAL_RANK:-0}}}"

    # Match the behavior of torch.distributed.run, which defaults
    # OMP_NUM_THREADS to 1 when more than one process runs per node and the
    # user has not chosen a value:
    # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/distributed/run.py#L521-L532
    # The per-node process count uses the same Slurm-then-Open-MPI fallback
    # chain as WORLD_SIZE/RANK/LOCAL_RANK above, so that mpirun launches get
    # the same default as srun launches.
    if [[ "${SLURM_NTASKS_PER_NODE:-${OMPI_COMM_WORLD_LOCAL_SIZE:-1}}" -gt 1 \
        && -z "${OMP_NUM_THREADS-}" ]]; then
        export OMP_NUM_THREADS=1
    fi

    # Optional one-line summary of the rendezvous settings, printed by global
    # rank 0 only so that large jobs do not repeat it once per process.
    if [[ -n "${NV_MLPERF_DEBUG:-}" && "${RANK}" == "0" ]]; then
        echo "slurm2pytorch: MASTER_ADDR=${MASTER_ADDR} MASTER_PORT=${MASTER_PORT} WORLD_SIZE=${WORLD_SIZE}"
    fi
fi # end if [[ -n "${PYTORCH_VERSION-}" ]]

#############################################################################
# Exec the child script with the new variables
#############################################################################

# Replace this shell with the command given as arguments; it inherits every
# variable exported above.
exec "$@"
